#include <stdio.h>
#include <rte_eal.h>
#include <rte_ethdev.h>
#include <arpa/inet.h>

// Ethernet port id passed to every rte_eth_* call in this file.
int global_portid = 0;

// Number of mbufs created in the packet pool (see rte_pktmbuf_pool_create in main).
#define NUM_MBUFS 4096
// Maximum number of packets requested from one rte_eth_rx_burst call.
#define BURST_SIZE 128

// Feature switches: ENABLE_SEND gates the TX queue setup and the reply-addressing
// globals; ENABLE_TCP gates the TCP state globals.
#define ENABLE_SEND 1
#define ENABLE_TCP 1

// Receive window advertised in every TCP segment built by ustack_encode_tcp_pkt.
#define TCP_INIT_WINDOWS 14600

// Not referenced in the visible code. Presumably Ethernet (14) + IPv4 (20) +
// UDP (8) / TCP (20) header bytes -- TODO confirm before relying on them.
#define SEND_UDP_HEADER_LENGTH 42
#define SEND_TCP_HEADER_LENGTH 54

// 1: echo the received TCP payload back with PSH|ACK; 0: answer with a bare ACK.
#define TCP_SEND_BACK_DATA 0

// 1: compile the UDP encoder and the UDP echo path; 0: UDP packets are ignored.
#define UDP 0

#if ENABLE_SEND
// Addressing written into the next outgoing frame by the encoders. main() fills
// these from a received packet with source and destination swapped, so the
// reply goes back to the sender. IPs and ports are copied verbatim from the
// received headers (no byte-order conversion is applied).
uint8_t global_smac[RTE_ETHER_ADDR_LEN];
uint8_t global_dmac[RTE_ETHER_ADDR_LEN];
uint32_t global_sip;
uint32_t global_dip;
uint16_t global_sport;
uint16_t global_dport;
#endif

#if ENABLE_TCP
// TCP flags byte of the most recently received segment.
uint8_t global_flags;
uint32_t global_d_seqnum;         // Peer's sequence number from the last received segment (host byte order)
uint32_t global_d_acknum;         // Peer's acknowledgment number from the last received segment (host byte order)
uint32_t global_s_seqnum = 12345; // Local sequence number sent in outgoing segments (initial value 12345)
uint32_t global_s_acknum;         // Local acknowledgment number sent in outgoing segments

// TCP connection states. Only the LISTEN -> SYN_RCVD -> ESTABLISHED transitions
// are driven by the code visible in this file.
typedef enum __USTACK_TCP_STATUS
{
  USTACK_TCP_STATUS_CLOSED = 0,
  USTACK_TCP_STATUS_LISTEN,
  USTACK_TCP_STATUS_SYN_RCVD,
  USTACK_TCP_STATUS_SYN_SENT,
  USTACK_TCP_STATUS_ESTABLISHED,
  USTACK_TCP_STATUS_FIN_WAIT_1,
  USTACK_TCP_STATUS_FIN_WAIT_2,
  USTACK_TCP_STATUS_CLOSING,
  USTACK_TCP_STATUS_TIMEWAIT,
  USTACK_TCP_STATUS_CLOSE_WAIT,
  USTACK_TCP_STATUS_LAST_ACK
} USTACK_TCP_STATUS;

// State of the single connection tracked by this file; starts out listening.
uint8_t tcp_status = USTACK_TCP_STATUS_LISTEN;
#endif

// Port configuration handed to rte_eth_dev_configure. Only the maximum RX packet
// length is set explicitly; every other field is zero-initialized.
static const struct rte_eth_conf port_conf_default = {
    .rxmode = {.max_rx_pkt_len = RTE_ETHER_MAX_LEN}};

/*
 * Configure and start the Ethernet port identified by global_portid.
 *
 * Sets up one RX queue (128 descriptors) fed from mbuf_pool and, when
 * ENABLE_SEND is non-zero, one TX queue (512 descriptors).
 *
 * mbuf_pool: pool the RX queue allocates receive buffers from.
 *
 * Returns 0 on success. Every failure terminates the process via rte_exit().
 */
static int ustack_init_port(struct rte_mempool *mbuf_pool)
{
  uint16_t nb_sys_ports = rte_eth_dev_count_avail();
  if (nb_sys_ports == 0)
  {
    rte_exit(EXIT_FAILURE, "No Supported eth found\n");
  }

  // NOTE(review): the result of rte_eth_dev_info_get is not checked because its
  // return type differs between DPDK releases (void in older ones) -- confirm
  // against the DPDK version this file is built with.
  struct rte_eth_dev_info dev_info;
  rte_eth_dev_info_get(global_portid, &dev_info);

  const int num_rx_queues = 1;
#if ENABLE_SEND
  const int num_tx_queues = 1;
#else
  const int num_tx_queues = 0;
#endif
  // A failed configure must stop here: the queue setup calls below are only
  // meaningful on a configured device.
  if (rte_eth_dev_configure(global_portid, num_rx_queues, num_tx_queues, &port_conf_default) < 0)
  {
    rte_exit(EXIT_FAILURE, "Could not configure port\n");
  }

  if (rte_eth_rx_queue_setup(global_portid, 0, 128,
                             rte_eth_dev_socket_id(global_portid), NULL, mbuf_pool) < 0)
  {
    rte_exit(EXIT_FAILURE, "Could not setup RX queue\n");
  }

#if ENABLE_SEND
  // Start from the driver's default TX configuration and apply the TX-side
  // offloads of the port configuration (RX offload bits do not belong here).
  struct rte_eth_txconf txq_conf = dev_info.default_txconf;
  txq_conf.offloads = port_conf_default.txmode.offloads;
  if (rte_eth_tx_queue_setup(global_portid, 0, 512,
                             rte_eth_dev_socket_id(global_portid), &txq_conf) < 0)
  {
    rte_exit(EXIT_FAILURE, "Could not setup TX queue\n");
  }
#endif

  if (rte_eth_dev_start(global_portid) < 0)
  {
    rte_exit(EXIT_FAILURE, "Could not start\n");
  }
  return 0;
}

#if UDP
/*
 * Build an Ethernet + IPv4 + UDP frame in msg using the global_* addressing.
 *
 * msg:       output buffer; must hold at least total_len bytes.
 * data:      UDP payload; total_len minus the three header sizes bytes are
 *            copied from it.
 * total_len: length of the whole frame, Ethernet header included.
 *
 * Returns 0 on success, -1 when total_len cannot even hold the three headers
 * (nothing is written in that case).
 */
static int ustack_encode_udp_pkt(uint8_t *msg, uint8_t *data, uint16_t total_len)
{
  const size_t hdrs_len = sizeof(struct rte_ether_hdr) + sizeof(struct rte_ipv4_hdr) +
                          sizeof(struct rte_udp_hdr);
  if (msg == NULL || total_len < hdrs_len)
  {
    return -1;
  }

  struct rte_ether_hdr *eth = (struct rte_ether_hdr *)msg;
  rte_memcpy(eth->d_addr.addr_bytes, global_dmac, RTE_ETHER_ADDR_LEN);
  rte_memcpy(eth->s_addr.addr_bytes, global_smac, RTE_ETHER_ADDR_LEN);
  eth->ether_type = htons(RTE_ETHER_TYPE_IPV4);

  struct rte_ipv4_hdr *ip = (struct rte_ipv4_hdr *)(eth + 1);
  ip->version_ihl = 0x45; // IPv4, 5 x 32-bit words: no IP options
  ip->type_of_service = 0;
  ip->total_length = htons(total_len - sizeof(struct rte_ether_hdr));
  ip->packet_id = 0;
  ip->fragment_offset = 0;
  ip->time_to_live = 64;
  ip->next_proto_id = IPPROTO_UDP;
  // Addresses and ports are stored as received, so no byte swapping here.
  ip->src_addr = global_sip;
  ip->dst_addr = global_dip;

  ip->hdr_checksum = 0;
  ip->hdr_checksum = rte_ipv4_cksum(ip);

  struct rte_udp_hdr *udp = (struct rte_udp_hdr *)(ip + 1);
  udp->src_port = global_sport;
  udp->dst_port = global_dport;
  uint16_t udplen = total_len - sizeof(struct rte_ether_hdr) - sizeof(struct rte_ipv4_hdr);
  udp->dgram_len = htons(udplen);

  // dgram_len counts the UDP header as well; only the remainder is payload.
  // Copying udplen bytes would read and write 8 bytes past the datagram.
  uint16_t payload_len = udplen - sizeof(struct rte_udp_hdr);
  if (data != NULL && payload_len > 0)
  {
    rte_memcpy((uint8_t *)(udp + 1), data, payload_len);
  }

  // The checksum covers the payload, so it is computed after the copy.
  udp->dgram_cksum = 0;
  udp->dgram_cksum = rte_ipv4_udptcp_cksum(ip, udp);
  return 0;
}
#endif

static int ustack_encode_tcp_pkt(uint8_t *msg, uint16_t total_len, uint8_t *data, uint16_t data_len, unsigned char flag)
{
  struct rte_ether_hdr *eth = (struct rte_ether_hdr *)msg;
  rte_memcpy(eth->d_addr.addr_bytes, global_dmac, RTE_ETHER_ADDR_LEN);
  rte_memcpy(eth->s_addr.addr_bytes, global_smac, RTE_ETHER_ADDR_LEN);
  eth->ether_type = htons(RTE_ETHER_TYPE_IPV4);

  struct rte_ipv4_hdr *ip = (struct rte_ipv4_hdr *)(eth + 1);
  ip->version_ihl = 0x45;
  ip->type_of_service = 0;
  ip->total_length = htons(total_len - sizeof(struct rte_ether_hdr));
  ip->packet_id = 0;
  ip->fragment_offset = 0;
  ip->time_to_live = 64;
  ip->next_proto_id = IPPROTO_TCP;
  ip->src_addr = global_sip;
  ip->dst_addr = global_dip;

  ip->hdr_checksum = 0;
  ip->hdr_checksum = rte_ipv4_cksum(ip);

  struct rte_tcp_hdr *tcp = (struct rte_tcp_hdr *)(ip + 1);
  tcp->src_port = global_sport;
  tcp->dst_port = global_dport;
  tcp->sent_seq = htonl(global_s_seqnum);
  tcp->recv_ack = htonl(global_s_acknum);
  tcp->data_off = 0x50;

  if (flag == 0)
  {
    tcp->tcp_flags = RTE_TCP_ACK_FLAG | RTE_TCP_SYN_FLAG;
  }
  else if (flag == 1)
  {
    tcp->tcp_flags = RTE_TCP_ACK_FLAG;
  }
  else if (flag == 2)
  {
    tcp->tcp_flags = RTE_TCP_ACK_FLAG | RTE_TCP_PSH_FLAG;
  }

  tcp->rx_win = htons(TCP_INIT_WINDOWS);
  tcp->cksum = 0;
  tcp->cksum = rte_ipv4_udptcp_cksum(ip, tcp);

  if (data != NULL && data_len > 0)
  {
    rte_memcpy((uint8_t *)(tcp + 1), data, data_len);
  }


  return 0;
}

// Helper that prints "<name>: xx:xx:xx:xx:xx:xx " to stdout: the label, then the
// six bytes of mac as two-digit lowercase hex, then one trailing space.
static void print_mac(const char *name, uint8_t *mac)
{
  enum { MAC_OCTETS = 6 };

  printf("%s: ", name);
  for (int idx = 0; idx < MAC_OCTETS; idx++)
  {
    // Octets are separated by ':'; the last one is followed by a space instead.
    char sep = (idx + 1 < MAC_OCTETS) ? ':' : ' ';
    printf("%02x%c", mac[idx], sep);
  }
}

int main(int argc, char *argv[])
{
  if (rte_eal_init(argc, argv) < 0)
  {
    rte_exit(EXIT_FAILURE, "Error with EAL init\n");
  }

  struct rte_mempool *mbuf_pool = rte_pktmbuf_pool_create("mbuf pool", NUM_MBUFS, 0, 0, RTE_MBUF_DEFAULT_BUF_SIZE, rte_socket_id());
  if (mbuf_pool == NULL)
  {
    rte_exit(EXIT_FAILURE, "Could not create mbuf pool\n");
  }

  ustack_init_port(mbuf_pool);

  while (1)
  {
    struct rte_mbuf *mbufs[BURST_SIZE] = {0};
    uint16_t num_recvd = rte_eth_rx_burst(global_portid, 0, mbufs, BURST_SIZE);
    if (num_recvd > BURST_SIZE)
    {
      rte_exit(EXIT_FAILURE, "Error receiving from eth\n");
    }

    for (int i = 0; i < num_recvd; i++)
    {
      struct rte_ether_hdr *ethhdr = rte_pktmbuf_mtod(mbufs[i], struct rte_ether_hdr *);
      if (ethhdr->ether_type != rte_cpu_to_be_16(RTE_ETHER_TYPE_IPV4))
      {
        continue;
      }

      struct rte_ipv4_hdr *iphdr = rte_pktmbuf_mtod_offset(mbufs[i], struct rte_ipv4_hdr *, sizeof(struct rte_ether_hdr));
      uint8_t ip_hdr_len = (iphdr->version_ihl & 0x0F) * 4;

      if (iphdr->next_proto_id == IPPROTO_UDP)
      {
#if UDP
        struct rte_udp_hdr *udphdr = (struct rte_udp_hdr *)(iphdr + 1);

        // 新增：UDP接收时打印MAC、IP、端口
        printf("[UDP接收] ");
        print_mac("源MAC", ethhdr->s_addr.addr_bytes);   // 接收包的源MAC是发送端MAC
        print_mac("目的MAC", ethhdr->d_addr.addr_bytes); // 接收包的目的MAC是本地MAC
        struct in_addr addr;
        addr.s_addr = iphdr->src_addr;
        printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(udphdr->src_port));
        addr.s_addr = iphdr->dst_addr;
        printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(udphdr->dst_port));
        printf("recv_udp: %s\n\n", (char *)(udphdr + 1));

#if ENABLE_SEND
        rte_memcpy(global_smac, ethhdr->d_addr.addr_bytes, RTE_ETHER_ADDR_LEN);
        rte_memcpy(global_dmac, ethhdr->s_addr.addr_bytes, RTE_ETHER_ADDR_LEN);
        rte_memcpy(&global_sip, &iphdr->dst_addr, sizeof(uint32_t));
        rte_memcpy(&global_dip, &iphdr->src_addr, sizeof(uint32_t));
        rte_memcpy(&global_sport, &udphdr->dst_port, sizeof(uint16_t));
        rte_memcpy(&global_dport, &udphdr->src_port, sizeof(uint16_t));
#if 0
        addr.s_addr = iphdr->src_addr;
        printf("sip %s:%d --> ", inet_ntoa(addr), ntohs(udphdr->src_port));
        addr.s_addr = iphdr->dst_addr;
        printf("dip %s:%d --> ", inet_ntoa(addr), ntohs(udphdr->dst_port));
#endif
        uint16_t length = ntohs(udphdr->dgram_len);
        uint16_t total_len = length + sizeof(struct rte_ipv4_hdr) + sizeof(struct rte_ether_hdr);

        struct rte_mbuf *mbuf = rte_pktmbuf_alloc(mbuf_pool);
        if (!mbuf)
        {
          rte_exit(EXIT_FAILURE, "Error rte_pktmbuf_alloc\n");
        }
        mbuf->pkt_len = total_len;
        mbuf->data_len = total_len;

        uint8_t *msg = rte_pktmbuf_mtod(mbuf, uint8_t *);
        ustack_encode_udp_pkt(msg, (uint8_t *)(udphdr + 1), total_len);

        // UDP发送前打印MAC、IP、端口
        printf("[UDP发送] ");
        print_mac("源MAC", global_smac);
        print_mac("目的MAC", global_dmac);
        addr.s_addr = global_sip;
        printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(global_sport));
        addr.s_addr = global_dip;
        printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(global_dport));

        rte_eth_tx_burst(global_portid, 0, &mbuf, 1);
#endif
        printf("send_udp : %s\n\n", (char *)(udphdr + 1));
#endif
      }
      else if (iphdr->next_proto_id == IPPROTO_TCP)
      {
        struct rte_tcp_hdr *tcphdr = (struct rte_tcp_hdr *)(iphdr + 1);
        uint8_t tcp_hdr_len = (tcphdr->data_off >> 4) * 4;
        uint16_t recv_tcp_data_len = rte_pktmbuf_data_len(mbufs[i]) -
                                     sizeof(struct rte_ether_hdr) - ip_hdr_len - tcp_hdr_len;

        rte_memcpy(global_smac, ethhdr->d_addr.addr_bytes, RTE_ETHER_ADDR_LEN);
        rte_memcpy(global_dmac, ethhdr->s_addr.addr_bytes, RTE_ETHER_ADDR_LEN);
        rte_memcpy(&global_sip, &iphdr->dst_addr, sizeof(uint32_t));
        rte_memcpy(&global_dip, &iphdr->src_addr, sizeof(uint32_t));
        rte_memcpy(&global_sport, &tcphdr->dst_port, sizeof(uint16_t));
        rte_memcpy(&global_dport, &tcphdr->src_port, sizeof(uint16_t));

        global_flags = tcphdr->tcp_flags;
        global_d_seqnum = ntohl(tcphdr->sent_seq);
        global_d_acknum = ntohl(tcphdr->recv_ack);

        // 新增：TCP接收时打印MAC、IP、端口
        printf("[TCP接收] ");
        uint8_t *recv_data = ((uint8_t *)tcphdr + tcp_hdr_len);
        print_mac("源MAC", ethhdr->s_addr.addr_bytes);
        print_mac("目的MAC", ethhdr->d_addr.addr_bytes);
        struct in_addr addr;
        addr.s_addr = iphdr->src_addr;
        printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(tcphdr->src_port));
        addr.s_addr = iphdr->dst_addr;
        printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(tcphdr->dst_port));
        printf("tcp recv data: %s, recv_data_len: %d, seqnum: %u, acknum: %u\n\n", recv_data, recv_tcp_data_len, global_d_seqnum, global_d_acknum);
#if 0
        addr.s_addr = iphdr->src_addr;
        printf("sip %s:%d --> ", inet_ntoa(addr), ntohs(tcphdr->src_port));
        addr.s_addr = iphdr->dst_addr;
        printf("dip %s:%d , flags: 0x%x, seqnum: %u, acknum: %u\n",
               inet_ntoa(addr), ntohs(tcphdr->dst_port),
               global_flags, global_d_seqnum, global_d_acknum);
#endif
        if (global_flags & RTE_TCP_SYN_FLAG)
        {
          if (tcp_status == USTACK_TCP_STATUS_LISTEN)
          {
            global_s_acknum = global_d_seqnum + 1;

            printf("[TCP SYN+ACK发送] ");
            print_mac("源MAC", global_smac);
            print_mac("目的MAC", global_dmac);
            addr.s_addr = global_sip;
            printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(global_sport));
            addr.s_addr = global_dip;
            printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(global_dport));
            printf("seq=%u, ack=%u ", global_s_seqnum, global_s_acknum);

            uint16_t total_len = sizeof(struct rte_tcp_hdr) + sizeof(struct rte_ipv4_hdr) + sizeof(struct rte_ether_hdr);
            struct rte_mbuf *mbuf = rte_pktmbuf_alloc(mbuf_pool);
            if (!mbuf)
            {
              rte_exit(EXIT_FAILURE, "Error rte_pktmbuf_alloc\n");
            }
            mbuf->pkt_len = total_len;
            mbuf->data_len = total_len;

            uint8_t *msg = rte_pktmbuf_mtod(mbuf, uint8_t *);
            ustack_encode_tcp_pkt(msg, total_len, NULL, 0, 0);

            // TCP SYN+ACK发送前打印MAC、IP、端口
            print_mac("源MAC", global_smac);
            print_mac("目的MAC", global_dmac);
            addr.s_addr = global_sip;
            printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(global_sport));
            addr.s_addr = global_dip;
            printf("目的IP:%s:%d\n\n", inet_ntoa(addr), ntohs(global_dport));

            rte_eth_tx_burst(global_portid, 0, &mbuf, 1);

            tcp_status = USTACK_TCP_STATUS_SYN_RCVD;
            global_s_seqnum++; // SYN包改变序列号
          }
        }

        if (global_flags & RTE_TCP_ACK_FLAG)
        {
          if (tcp_status == USTACK_TCP_STATUS_SYN_RCVD)
          {
            printf("enter established\n\n");
            tcp_status = USTACK_TCP_STATUS_ESTABLISHED;
          }
        }

        if (global_flags & RTE_TCP_PSH_FLAG && tcp_status == USTACK_TCP_STATUS_ESTABLISHED)
        {


#if TCP_SEND_BACK_DATA

          uint16_t send_tcp_data_len = recv_tcp_data_len;
          uint8_t *send_data = recv_data;
#else

          uint16_t send_tcp_data_len = 0;
          uint8_t *send_data = NULL;
#endif
          global_s_acknum = global_d_seqnum + recv_tcp_data_len;
        

          uint16_t total_len = sizeof(struct rte_tcp_hdr) + sizeof(struct rte_ipv4_hdr) + sizeof(struct rte_ether_hdr) + send_tcp_data_len;


          struct rte_mbuf *reply_mbuf = rte_pktmbuf_alloc(mbuf_pool);
          if (!reply_mbuf)
          {
            rte_exit(EXIT_FAILURE, "TCP reply mbuf alloc failed\n");
          }
          reply_mbuf->pkt_len = total_len;
          reply_mbuf->data_len = total_len;

          uint8_t *reply_msg = rte_pktmbuf_mtod(reply_mbuf, uint8_t *);

#if TCP_SEND_BACK_DATA

          ustack_encode_tcp_pkt(reply_msg, total_len, recv_data, send_tcp_data_len, 2);

          // TCP数据回复（无数据）发送前打印MAC、IP、端口
          printf("[TCP PSH+ACK发送] ");
          print_mac("源MAC", global_smac);
          print_mac("目的MAC", global_dmac);
          addr.s_addr = global_sip;
          printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(global_sport));
          addr.s_addr = global_dip;
          printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(global_dport));

          rte_eth_tx_burst(global_portid, 0, &reply_mbuf, 1);
          printf("send_data = %s, send_data_len = %d, seq=%u, ack=%u\n\n", send_data, send_tcp_data_len, global_s_seqnum, global_s_acknum);
          global_s_seqnum += send_tcp_data_len;

#else
          ustack_encode_tcp_pkt(reply_msg, total_len, send_data, send_tcp_data_len, 1);

          // TCP数据回复（带数据）发送前打印MAC、IP、端口
          printf("[TCP ACK发送] ");
          print_mac("源MAC", global_smac);
          print_mac("目的MAC", global_dmac);
          addr.s_addr = global_sip;
          printf("源IP:%s:%d ", inet_ntoa(addr), ntohs(global_sport));
          addr.s_addr = global_dip;
          printf("目的IP:%s:%d \n", inet_ntoa(addr), ntohs(global_dport));

          rte_eth_tx_burst(global_portid, 0, &reply_mbuf, 1);
          printf("seq=%u, ack=%u\n\n", global_s_seqnum, global_s_acknum);
          // global_s_acknum++; // ACK包不改变序列号

#endif
        }
      }
      rte_pktmbuf_free(mbufs[i]);
    }
  }
  return 0;
}